# coding:utf-8
# Author : hiicy redldw
# Date : 2019/04/17
import os
import keras.applications.mobilenetv2
from sklearn.model_selection import train_test_split
import tensorflow as tf
import matplotlib.pyplot as plt

def Conv2dBn(input, kernel_size=(3, 3), nfilter=32, stride=(1, 1), padding="valid", activation='relu', name=None):
    """Apply a 2-D convolution followed by batch normalization.

    The activation is handed to the convolution layer itself, so it runs
    before the normalization step. ``name`` is given to the convolution
    layer only; the normalization layer keeps its auto-generated name.

    Args:
        input: Tensor fed to the convolution layer.
        kernel_size: Convolution kernel size.
        nfilter: Number of convolution filters (output channels).
        stride: Convolution strides.
        padding: Padding mode forwarded to the convolution layer.
        activation: Activation forwarded to the convolution layer.
        name: Optional name for the convolution layer.

    Returns:
        The batch-normalized output tensor.
    """
    convolved = tf.keras.layers.Conv2D(
        filters=nfilter,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        activation=activation,
        name=name,
    )(input)
    # axis=3 selects the feature axis to normalize over.
    return tf.keras.layers.BatchNormalization(axis=3)(convolved)


def Pool(input, pool_size=(3, 3), stride=(1, 1), pool_type='avg',padding='valid',name=None):
    """Apply a 2-D average or max pooling layer to ``input``.

    Args:
        input: Tensor fed to the pooling layer.
        pool_size: Pooling window size.
        stride: Pooling strides.
        pool_type: ``"avg"`` or ``"max"`` (case-insensitive).
        padding: Padding mode forwarded to the pooling layer.
        name: Optional layer name.

    Returns:
        The pooled tensor.

    Raises:
        ValueError: If ``pool_type`` is neither ``"avg"`` nor ``"max"``.
    """
    kind = str.lower(pool_type)
    # Validate before building anything: previously an unknown type fell
    # through both branches and crashed on an unbound local variable.
    if kind == "avg":
        layer_cls = tf.keras.layers.AveragePooling2D
    elif kind == "max":
        layer_cls = tf.keras.layers.MaxPooling2D
    else:
        raise ValueError(
            f"unsupported pool_type {pool_type!r}; expected 'avg' or 'max'"
        )
    return layer_cls(pool_size=pool_size, strides=stride, padding=padding, name=name)(input)


class MyInceptionV3(object):
    """Inception-V3-style image classifier assembled from tf.keras layers.

    Typical use: construct, call ``build()`` to create and compile the
    model, then ``train()``. The shape comments inside ``build`` and its
    helpers assume a 299*299*3 input.
    """

    def __init__(self, shape, classes=3,ckptpath=""):
        """Store the configuration; no model is created until ``build()``.

        Args:
            shape: Input shape handed to ``tf.keras.layers.Input``.
            classes: Number of units in the final softmax layer.
            ckptpath: Path the trained model is saved to by ``train()``.
        """
        self.shape = shape
        self.classes = classes
        self.layers = None
        self.model = None
        self.ckptPath = ckptpath

    @staticmethod
    def _module_35(x, pool_filters):
        """Four-branch module on the 35*35 grid; ``pool_filters`` sizes the pooled branch."""
        branch1 = Conv2dBn(x, kernel_size=(1, 1), nfilter=64)  # 35*35*64

        branch2 = Conv2dBn(x, kernel_size=(1, 1), nfilter=48)  # 35*35*48
        branch2 = Conv2dBn(branch2, kernel_size=(5, 5), nfilter=64, padding='same')  # 35*35*64

        # The 1x1 conv could be shared with branch1; it is kept separate for clarity.
        branch3 = Conv2dBn(x, kernel_size=(1, 1), nfilter=64)  # 35*35*64
        branch3 = Conv2dBn(branch3, kernel_size=(3, 3), nfilter=96, padding='same')  # 35*35*96
        branch3 = Conv2dBn(branch3, kernel_size=(3, 3), nfilter=96, padding='same')  # 35*35*96

        branch4 = Pool(x, pool_type='avg', padding='same')  # channels unchanged
        branch4 = Conv2dBn(branch4, kernel_size=(1, 1), nfilter=pool_filters)
        return tf.keras.layers.concatenate(
            [branch1, branch2, branch3, branch4],
            axis=3,
        )

    @staticmethod
    def _reduce_35_to_17(x):
        """Stride-2 module that shrinks the 35*35 grid to 17*17."""
        branch1 = Conv2dBn(x, nfilter=384, stride=(2, 2))  # 17*17*384

        # Each conv must consume the previous one; the original applied all
        # three to ``x`` and silently dropped the first two.
        branch2 = Conv2dBn(x, kernel_size=(1, 1), nfilter=64)  # 35*35*64
        branch2 = Conv2dBn(branch2, nfilter=96, padding='same')  # 35*35*96
        branch2 = Conv2dBn(branch2, nfilter=96, stride=(2, 2))  # 17*17*96

        branch3 = Pool(x, pool_type='max', stride=(2, 2))  # 17*17*288

        return tf.keras.layers.concatenate(
            [branch1, branch2, branch3],
            axis=3,
        )  # 17*17*768

    @staticmethod
    def _module_17(x, mid_filters):
        """Four-branch module on the 17*17 grid using factorized 1x7 / 7x1 convolutions."""
        branch1 = Conv2dBn(x, kernel_size=(1, 1), nfilter=192)  # 17*17*192

        branch2 = Conv2dBn(x, kernel_size=(1, 1), nfilter=mid_filters)
        branch2 = Conv2dBn(branch2, kernel_size=(1, 7), nfilter=mid_filters, padding='same')
        branch2 = Conv2dBn(branch2, kernel_size=(7, 1), nfilter=192, padding='same')  # 17*17*192

        branch3 = Conv2dBn(x, kernel_size=(1, 1), nfilter=mid_filters)
        branch3 = Conv2dBn(branch3, kernel_size=(7, 1), nfilter=mid_filters, padding='same')
        branch3 = Conv2dBn(branch3, kernel_size=(1, 7), nfilter=mid_filters, padding='same')
        branch3 = Conv2dBn(branch3, kernel_size=(7, 1), nfilter=mid_filters, padding='same')
        branch3 = Conv2dBn(branch3, kernel_size=(1, 7), nfilter=192, padding='same')  # 17*17*192

        branch4 = Pool(x, pool_type='avg', padding='same')  # 17*17*768
        branch4 = Conv2dBn(branch4, kernel_size=(1, 1), nfilter=192)  # 17*17*192
        return tf.keras.layers.concatenate(
            [branch1, branch2, branch3, branch4],
            axis=3,
        )  # 17*17*768

    @staticmethod
    def _reduce_17_to_8(x):
        """Stride-2 module that shrinks the 17*17 grid to 8*8."""
        branch1 = Conv2dBn(x, kernel_size=(1, 1), nfilter=192)  # 17*17*192
        branch1 = Conv2dBn(branch1, nfilter=320, stride=(2, 2))  # 8*8*320

        branch2 = Conv2dBn(x, kernel_size=(1, 1), nfilter=192)  # 17*17*192
        branch2 = Conv2dBn(branch2, kernel_size=(1, 7), nfilter=192, padding='same')  # 17*17*192
        branch2 = Conv2dBn(branch2, kernel_size=(7, 1), nfilter=192, padding='same')  # 17*17*192
        branch2 = Conv2dBn(branch2, nfilter=192, stride=(2, 2))  # 8*8*192

        branch3 = Pool(x, pool_type='max', stride=(2, 2))  # 8*8*768
        return tf.keras.layers.concatenate(
            [branch1, branch2, branch3], axis=3
        )  # 8*8*1280

    @staticmethod
    def _module_8(x):
        """Four-branch module on the 8*8 grid with split 1x3 / 3x1 convolutions."""
        branch1 = Conv2dBn(x, kernel_size=(1, 1), nfilter=320)  # 8*8*320

        branch2 = Conv2dBn(x, kernel_size=(1, 1), nfilter=384)  # 8*8*384
        branch2_1 = Conv2dBn(branch2, kernel_size=(1, 3), nfilter=384, padding='same')  # 8*8*384
        branch2_2 = Conv2dBn(branch2, kernel_size=(3, 1), nfilter=384, padding='same')  # 8*8*384
        branch2 = tf.keras.layers.concatenate([branch2_1, branch2_2], axis=3)  # 8*8*768

        branch3 = Conv2dBn(x, kernel_size=(1, 1), nfilter=448)  # 8*8*448
        branch3 = Conv2dBn(branch3, nfilter=384, padding='same')  # 8*8*384
        branch3_1 = Conv2dBn(branch3, kernel_size=(1, 3), nfilter=384, padding='same')  # 8*8*384
        branch3_2 = Conv2dBn(branch3, kernel_size=(3, 1), nfilter=384, padding='same')  # 8*8*384
        branch3 = tf.keras.layers.concatenate([branch3_1, branch3_2], axis=3)  # 8*8*768

        branch4 = Pool(x, pool_type='avg', padding='same')  # spatial size and channels unchanged
        branch4 = Conv2dBn(branch4, kernel_size=(1, 1), nfilter=192)  # 8*8*192

        return tf.keras.layers.concatenate(
            [branch1, branch2, branch3, branch4],
            axis=3,
        )  # 8*8*2048

    @staticmethod
    def _metric_series(records, *names):
        """Return the first series in ``records`` stored under any of ``names``.

        Raises:
            KeyError: If none of ``names`` is present in ``records``.
        """
        for name in names:
            if name in records:
                return records[name]
        raise KeyError(names[0])

    def build(self):
        """Create the network, compile it and store it in ``self.model``.

        The model is compiled with RMSprop, categorical cross-entropy and the
        ``acc`` metric, and its summary is printed. Returns ``None``.
        """
        base_input = tf.keras.layers.Input(shape=self.shape)  # 299*299*3

        # Stem: plain convolutions and max pooling down to a 35*35 grid.
        with tf.name_scope("first"):
            x = Conv2dBn(base_input, stride=(2, 2))  # 149*149*32
        with tf.name_scope("second"):
            x = Conv2dBn(x)  # 147*147*32
        with tf.name_scope("third"):
            x = Conv2dBn(x, nfilter=64, padding='same')  # 147*147*64
        with tf.name_scope("fourth"):
            x = Pool(x, pool_type="max", stride=(2, 2))  # 73*73*64
        with tf.name_scope("fifth"):
            x = Conv2dBn(x, kernel_size=(1, 1), nfilter=80)  # 73*73*80
        with tf.name_scope("sixth"):
            x = Conv2dBn(x, nfilter=192)  # 71*71*192
        with tf.name_scope("seventh"):
            x = Pool(x, pool_type="max", stride=(2, 2))  # 35*35*192

        with tf.name_scope("block1"):
            # Outputs: 35*35*256, then 35*35*288 twice.
            for index, pool_filters in enumerate((32, 64, 64), start=1):
                with tf.name_scope(f"module{index}"):
                    x = self._module_35(x, pool_filters)

        with tf.name_scope("block2"):
            with tf.name_scope("module1"):
                x = self._reduce_35_to_17(x)  # 17*17*768
            for index, mid_filters in enumerate((128, 160, 160, 192), start=2):
                with tf.name_scope(f"module{index}"):
                    x = self._module_17(x, mid_filters)  # 17*17*768

        with tf.name_scope("block3"):
            with tf.name_scope("module1"):
                x = self._reduce_17_to_8(x)  # 8*8*1280
            for index in (2, 3):
                with tf.name_scope(f"module{index}"):
                    x = self._module_8(x)  # 8*8*2048

        with tf.name_scope('output'):
            x = tf.keras.layers.GlobalAveragePooling2D()(x)  # 2048
            pred = tf.keras.layers.Dense(self.classes, activation='softmax')(x)  # self.classes

        self.model = tf.keras.Model(inputs=base_input, outputs=pred)
        optimizer = tf.keras.optimizers.RMSprop()
        self.model.compile(loss="categorical_crossentropy",
                           metrics=['acc'],
                           optimizer=optimizer)
        # summary() prints by itself; wrapping it in print() added a stray "None".
        self.model.summary()
        return

    def predict(self, ModelPath, x=None):
        """Ensure a model is available and optionally run inference.

        Args:
            ModelPath: Path given to ``tf.keras.models.load_model`` when no
                model is held in memory yet; ignored otherwise.
            x: Optional input data. When omitted the method only loads the
                model, which is all the original implementation did.

        Returns:
            The result of ``self.model.predict(x)``, or ``None`` if ``x`` is
            ``None``.
        """
        if self.model is None:
            self.model = tf.keras.models.load_model(ModelPath)
        if x is None:
            return None
        return self.model.predict(x)

    def train(self, x, y, batch=32, epoch=10, is_save=True):
        """Fit the model, evaluate it on a held-out split, plot and save.

        20% of ``x``/``y`` is held out (fixed ``random_state``) purely for the
        final evaluation. The remaining data is used for fitting, with its
        own 20% validation split.

        Args:
            x: Training inputs.
            y: Training targets.
            batch: Batch size for fitting and evaluation.
            epoch: Number of training epochs.
            is_save: When true, save the model to ``self.ckptPath``.
        """
        assert self.model is not None, "no model available; call build() first"
        train_X, test_X, train_y, test_y = train_test_split(x,
                                                            y,
                                                            test_size=0.2,
                                                            random_state=20)
        self.ckptDir = os.path.dirname(self.ckptPath)
        # TODO: feed the data through a generator instead of in-memory arrays.
        # Fit on the training split only, so the test split stays unseen.
        history = self.model.fit(train_X, train_y,
                                 batch_size=batch, epochs=epoch, verbose=1,
                                 shuffle=True, validation_split=0.2)

        loss, accuracy = self.model.evaluate(test_X, test_y, batch_size=batch)
        print('test loss: ', loss)
        print('test accuracy: ', accuracy)

        self._plotLoss(history)
        if is_save:
            if self.ckptDir:
                # Saving fails if the target directory is missing.
                os.makedirs(self.ckptDir, exist_ok=True)
            self.model.save(self.ckptPath)

    def _plotLoss(self, history):
        """Plot loss and accuracy curves side by side and save them to ``imshow.jpg``.

        Args:
            history: Object whose ``history`` attribute maps metric names to
                per-epoch value lists (as returned by ``Model.fit``).
        """
        records = history.history
        # NOTE(review): some Keras versions may use the long metric name; accept both.
        acc = self._metric_series(records, 'acc', 'accuracy')
        val_acc = self._metric_series(records, 'val_acc', 'val_accuracy')
        loss = records['loss']
        val_loss = records['val_loss']
        train_epochs = list(range(1, len(loss) + 1))
        val_epochs = list(range(1, len(val_loss) + 1))

        fig = plt.figure()
        # 'mfc' (marker face colour) replaces 'mc', which is not a line property.
        ax1 = fig.add_subplot(121)
        ax1.plot(train_epochs, loss, marker='o', c='c', mfc='g', lw=2, label="训练损失")
        ax1.plot(val_epochs, val_loss, marker='o', c='r', mfc='#D64474', label="验证损失")
        ax1.legend()

        ax2 = fig.add_subplot(122)
        ax2.plot(train_epochs, acc, marker='o', c='c', mfc='g', lw=2, label="训练精度")
        ax2.plot(val_epochs, val_acc, marker='o', c='r', mfc='#D64474', label="验证精度")
        ax2.legend()

        fig.savefig("imshow.jpg")
        # Release the figure so repeated training runs do not accumulate them.
        plt.close(fig)

    def __str__(self):
        return "i love you ldw!"


# Module-level demo: runs on import and builds the network for 299*299 RGB
# input with 8 output classes.
myInceptionV3 = MyInceptionV3(classes=8, shape=(299, 299, 3))
# build() has no return value, so this prints "None" after the model summary.
print(myInceptionV3.build())
